# -*- coding: utf-8 -*-
# File: format.py


import numpy as np
import os
import six
from six.moves import range
import struct

from ..utils import logger
from ..utils.argtools import log_once
from ..utils.compatible_serialize import loads
from ..utils.develop import create_dummy_class  # noqa
from ..utils.loadcaffe import get_caffe_pb
from ..utils.timer import timed_operation
from ..utils.utils import get_tqdm
from .base import DataFlowReentrantGuard, RNGDataFlow
from .common import MapData
from .base import DataFlow

__all__ = ['HDF5Data', 'LMDBData', 'LMDBDataDecoder',
           'CaffeLMDB', 'SVMLightData',
           'LMDBDataPoint', 'LMDBDataPointIndexed',
           'TFRecordData', 'BinaryData', 'NPZData']

"""
Adapters for different data format.
"""


class HDF5Data(RNGDataFlow):
    """
    Zip data from different paths in an HDF5 file.

    Warning:
        The current implementation will load all data into memory. (TODO)
    """
# TODO

    def __init__(self, filename, data_paths, shuffle=True):
        """
        Args:
            filename (str): h5 data file.
            data_paths (list): list of h5 paths to zipped.
                For example `['images', 'labels']`.
            shuffle (bool): shuffle all data.
        """
        self.f = h5py.File(filename, 'r')
        logger.info("Loading {} to memory...".format(filename))
        # `dataset.value` was deprecated and removed in h5py>=3.0;
        # `dataset[()]` reads the full dataset and is the supported equivalent.
        self.dps = [self.f[k][()] for k in data_paths]
        lens = [len(k) for k in self.dps]
        assert all(k == lens[0] for k in lens), \
            "All data paths must have the same length, got {}".format(lens)
        self._size = lens[0]
        self.shuffle = shuffle

    def __len__(self):
        return self._size

    def __iter__(self):
        # Yield one row from each zipped dataset, optionally in shuffled order.
        idxs = list(range(self._size))
        if self.shuffle:
            self.rng.shuffle(idxs)
        for k in idxs:
            yield [dp[k] for dp in self.dps]


class LMDBData(RNGDataFlow):
    """
    Read a LMDB database and produce (k,v) raw bytes pairs.
    The raw bytes are usually not what you're interested in.
    You might want to use
    :class:`LMDBDataDecoder` or apply a
    mapper function after :class:`LMDBData`.
    """
    def __init__(self, lmdb_path, shuffle=True, keys=None):
        """
        Args:
            lmdb_path (str): a directory or a file.
            shuffle (bool): shuffle the keys or not.
            keys (list[str] or str): list of str as the keys, used only when shuffle is True.
                It can also be a format string e.g. ``{:0>8d}`` which will be
                formatted with the indices from 0 to *total_size - 1*.

                If not given, it will then look in the database for ``__keys__`` which
                :func:`LMDBSerializer.save` used to store the list of keys.
                If still not found, it will iterate over the database to find
                all the keys.
        """
        self._lmdb_path = lmdb_path
        self._shuffle = shuffle

        self._open_lmdb()
        self._size = self._txn.stat()['entries']
        self._set_keys(keys)
        logger.info("Found {} entries in {}".format(self._size, self._lmdb_path))
        # A single txn/cursor is not reentrant; guard against nested iteration.
        self._guard = DataFlowReentrantGuard()

    def _set_keys(self, keys=None):
        def find_keys(txn, size):
            # Last resort: walk the whole database to collect keys.
            logger.warn("Traversing the database to find keys is slow. You should specify the keys.")
            keys = []
            with timed_operation("Loading LMDB keys ...", log_start=True), \
                    get_tqdm(total=size) as pbar:
                # Use the txn argument (the original read self._txn directly,
                # silently ignoring the parameter).
                for k in txn.cursor():
                    assert k[0] != b'__keys__'
                    keys.append(k[0])
                    pbar.update()
            return keys

        self.keys = self._txn.get(b'__keys__')
        if self.keys is not None:
            self.keys = loads(self.keys)
            self._size -= 1     # exclude the '__keys__' metadata entry from the size

        if self._shuffle:   # keys are necessary when shuffle is True
            if keys is None:
                if self.keys is None:
                    self.keys = find_keys(self._txn, self._size)
            else:
                # check if a key-format string like '{:0>8d}' was given
                if isinstance(keys, six.string_types):
                    # Must be a list, not `map`: under Python 3 `map` returns
                    # an iterator, which rng.shuffle() in __iter__ cannot handle.
                    self.keys = [keys.format(x) for x in range(self._size)]
                else:
                    self.keys = keys

    def _open_lmdb(self):
        # readonly + lock=False allows concurrent readers of the same database.
        self._lmdb = lmdb.open(self._lmdb_path,
                               subdir=os.path.isdir(self._lmdb_path),
                               readonly=True, lock=False, readahead=True,
                               map_size=1099511627776 * 2, max_readers=100)
        self._txn = self._lmdb.begin()

    def reset_state(self):
        # Reopen the environment on reset — presumably so that forked worker
        # processes each get a fresh handle (lmdb handles are not fork-safe).
        self._lmdb.close()
        super(LMDBData, self).reset_state()
        self._open_lmdb()

    def __len__(self):
        return self._size

    def __iter__(self):
        with self._guard:
            if not self._shuffle:
                # Sequential scan; skip the '__keys__' metadata entry.
                c = self._txn.cursor()
                while c.next():
                    k, v = c.item()
                    if k != b'__keys__':
                        yield [k, v]
            else:
                self.rng.shuffle(self.keys)
                for k in self.keys:
                    v = self._txn.get(k)
                    yield [k, v]


class LMDBDataDecoder(MapData):
    """ Read a LMDB database with a custom decoder and produce decoded outputs."""

    def __init__(self, lmdb_data, decoder):
        """
        Args:
            lmdb_data: a :class:`LMDBData` instance.
            decoder (k,v -> dp | None): a function taking k, v and returning a datapoint,
                or return None to discard.
        """
        # Each raw datapoint from LMDBData is a [key, value] pair;
        # forward it to the user-supplied decoder.
        super(LMDBDataDecoder, self).__init__(
            lmdb_data, lambda dp: decoder(dp[0], dp[1]))


class LMDBDataPoint(MapData):
    """
    Read a LMDB file and produce deserialized datapoints.
    It **only** accepts the database produced by
    :func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`,
    which uses :func:`tensorpack.utils.serialize.dumps` for serialization.

    Example:
        .. code-block:: python

            ds = LMDBDataPoint("/data/ImageNet.lmdb", shuffle=False)  # read and decode

            # The above is equivalent to:
            ds = LMDBData("/data/ImageNet.lmdb", shuffle=False)  # read
            ds = LMDBDataPoint(ds)  # decode
            # Sometimes it makes sense to separate reading and decoding
            # to be able to make decoding parallel.
    """

    def __init__(self, *args, **kwargs):
        """
        Args:
            args, kwargs: Same as in :class:`LMDBData`.

        In addition, args[0] can be a :class:`LMDBData` instance.
        In this case args[0] has to be the only argument.
        """
        # Either wrap an existing dataflow, or build an LMDBData from the args.
        if isinstance(args[0], DataFlow):
            assert len(args) == 1 and len(kwargs) == 0, \
                "No more arguments are allowed if LMDBDataPoint is called with a LMDBData instance!"
            ds = args[0]
        else:
            ds = LMDBData(*args, **kwargs)

        # dp is [key, serialized_value]; keep only the deserialized value.
        super(LMDBDataPoint, self).__init__(ds, lambda dp: loads(dp[1]))


class LMDBDataPointIndexed(MapData):
    """
    Read a LMDB file and produce deserialized values.
    This can work with :func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`.

    The input DS can be processed already (with join, shuffle, etc), so the
    index of key and val are specified.

    Furthermore, there could be multiple LMDB data fused together before local
    shuffle.
    """

    def __init__(self, ds, index=1):
        """
        Args:
            ds: the upstream dataflow producing [..., key, serialized_value, ...].
            index (int): position of the serialized value in each datapoint;
                the key is assumed to sit at ``index - 1``.
        """
        def _decode(dp):
            # Splice the deserialized fields over the [key, serialized value]
            # pair in place, then return the modified datapoint.
            dp[index - 1:index + 1] = loads(dp[index])
            return dp

        super(LMDBDataPointIndexed, self).__init__(ds, _decode)


def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
    """
    Read a Caffe LMDB file where each value contains a ``caffe.Datum`` protobuf.
    Produces datapoints of the format: [HWC image, label].

    Note that Caffe LMDB format is not efficient: it stores serialized raw
    arrays rather than JPEG images.

    Args:
        lmdb_path, shuffle, keys: same as :class:`LMDBData`.

    Returns:
        a :class:`LMDBDataDecoder` instance.

    Example:
        .. code-block:: python

            ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')
    """

    cpb = get_caffe_pb()
    lmdb_data = LMDBData(lmdb_path, shuffle, keys)

    def decoder(k, v):
        # Decode one CHW raw-array Datum; return None to drop unreadable entries.
        try:
            datum = cpb.Datum()
            datum.ParseFromString(v)
            # np.fromstring is deprecated for binary input; frombuffer is the
            # drop-in replacement for bytes.
            img = np.frombuffer(datum.data, dtype=np.uint8)
            img = img.reshape(datum.channels, datum.height, datum.width)
        except Exception:
            log_once("Cannot read key {}".format(k), 'warn')
            return None
        return [img.transpose(1, 2, 0), datum.label]
    logger.warn("Caffe LMDB format doesn't store jpeg-compressed images, \
        it's not recommended due to its inferior performance.")
    return LMDBDataDecoder(lmdb_data, decoder)


class SVMLightData(RNGDataFlow):
    """ Read X,y from an SVMlight file, and produce [X_i, y_i] pairs. """

    def __init__(self, filename, shuffle=True):
        """
        Args:
            filename (str): input file
            shuffle (bool): shuffle the data
        """
        import sklearn.datasets  # noqa
        self.X, self.y = sklearn.datasets.load_svmlight_file(filename)
        # Densify the sparse feature matrix so rows can be indexed directly.
        self.X = np.asarray(self.X.todense())
        self.shuffle = shuffle

    def __len__(self):
        return len(self.y)

    def __iter__(self):
        # Yield [features, label] pairs, optionally in shuffled order.
        indices = np.arange(len(self))
        if self.shuffle:
            self.rng.shuffle(indices)
        for idx in indices:
            yield [self.X[idx, :], self.y[idx]]


class TFRecordData(DataFlow):
    """
    Produce datapoints from a TFRecord file, assuming each record is
    serialized by :func:`serialize.dumps`.
    This class works with :func:`dftools.dump_dataflow_to_tfrecord`.
    """
    def __init__(self, path, size=None):
        """
        Args:
            path (str): path to the tfrecord file
            size (int): total number of records, because this metadata is not
                stored in the tfrecord file. Optional; if not given, fall back
                to the base class's size().
        """
        self._path = path
        # `size` defaults to None, but int(None) raises TypeError —
        # only convert when a size was actually provided.
        self._size = int(size) if size is not None else None

    def size(self):
        if self._size:
            return self._size
        return super(TFRecordData, self).size()

    def get_data(self):
        # NOTE(review): `tf` is not imported anywhere in this file —
        # presumably provided by the importing environment; verify before use.
        gen = tf.python_io.tf_record_iterator(self._path)
        for dp in gen:
            yield loads(dp)

class BinaryData(DataFlow):
    """
    Produce datapoints from a binary file of fixed-width rows.

    Each row consists of ``n_val_per_row`` 4-byte floats (native byte order,
    as interpreted by :mod:`struct` format ``'f'``); rows are decoded one at
    a time, so the whole file is never loaded into memory.
    """
    def __init__(self, path, n_val_per_row=None):
        """
        Args:
            path (str): path to the binary file.
            n_val_per_row (int): number of float32 values per row.
        """
        self.dp_format = 'f' * n_val_per_row
        self.row_n_bytes = 4 * n_val_per_row
        self.path = path
        self.statinfo = os.stat(self.path)
        # Floor division: '/' yields a float under Python 3, which would
        # break `range(self._size)` in get_data().
        self._size = self.statinfo.st_size // (4 * n_val_per_row)

    def size(self):
        return self._size

    def get_data(self):
        with open(self.path, 'rb') as fin:
            for _ in range(self._size):
                # Read and decode exactly one row per datapoint.
                yield [np.asarray(
                    struct.unpack(self.dp_format, fin.read(self.row_n_bytes)),
                    dtype=np.float32)]

class NPZData(DataFlow):
    """
    Produce datapoints from arrays stored in a ``.npz`` file.

    Each datapoint is the i-th row of every selected array, zipped together:
    ``[data[k][i] for k in keys]``.
    """

    def __init__(self, path, keys, select_range=None):
        """
        Args:
            path (str): path to the .npz file.
            keys (list[str]): names of the arrays to produce, in order.
            select_range: optional (start, end) pair; when given, only rows
                ``start:end`` of every array are kept.
        """
        self.path = path
        self.keys = keys
        self.data = np.load(path)

        if select_range is not None:
            start, end = select_range[0], select_range[1]
            # Materialize the sliced arrays into a plain dict.
            self.data_temp = {key: self.data[key][start:end] for key in keys}
            self.data = self.data_temp

    def size(self):
        # Assumes every array in `keys` has the same first dimension.
        return self.data[self.keys[0]].shape[0]

    def get_data(self):
        for idx in range(self.size()):
            yield [self.data[key][idx] for key in self.keys]


# Replace classes whose optional dependency is missing with dummy classes that
# raise an informative ImportError when instantiated.
# NOTE(review): `create_dummy_class` is already imported at the top of this
# file; this re-import is redundant but harmless.
from ..utils.develop import create_dummy_class   # noqa
try:
    import h5py
except ImportError:
    HDF5Data = create_dummy_class('HDF5Data', 'h5py')   # noqa

try:
    import lmdb
except ImportError:
    # NOTE(review): LMDBDataPoint / LMDBDataPointIndexed are not replaced here,
    # even though LMDBDataPoint may construct an LMDBData (and thus need lmdb).
    for klass in ['LMDBData', 'LMDBDataDecoder', 'CaffeLMDB']:
        globals()[klass] = create_dummy_class(klass, 'lmdb')
